{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logging.basicConfig(level=logging.DEBUG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG:tensorflow:Falling back to TensorFlow client; we recommended you install the Cloud TPU client directly with pip install cloud-tpu-client.\n",
      "DEBUG:h5py._conv:Creating converter from 7 to 5\n",
      "DEBUG:h5py._conv:Creating converter from 5 to 7\n",
      "DEBUG:h5py._conv:Creating converter from 7 to 5\n",
      "DEBUG:h5py._conv:Creating converter from 5 to 7\n",
      "DEBUG:git.cmd:Popen(['git', 'version'], cwd=/home/husein/dev/malaya, universal_newlines=False, shell=None, istream=None)\n",
      "DEBUG:git.cmd:Popen(['git', 'version'], cwd=/home/husein/dev/malaya, universal_newlines=False, shell=None, istream=None)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.28 s, sys: 3.46 s, total: 6.75 s\n",
      "Wall time: 2.43 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "import malaya"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from malaya.torch_model.t5 import T5Diaparser\n",
    "from transformers import T5Tokenizer, T5Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tag2idx = {'PAD': 0,\n",
    " 'X': 1,\n",
    " 'nsubj': 2,\n",
    " 'cop': 3,\n",
    " 'det': 4,\n",
    " 'root': 5,\n",
    " 'nsubj:pass': 6,\n",
    " 'acl': 7,\n",
    " 'case': 8,\n",
    " 'obl': 9,\n",
    " 'flat': 10,\n",
    " 'punct': 11,\n",
    " 'appos': 12,\n",
    " 'amod': 13,\n",
    " 'compound': 14,\n",
    " 'advmod': 15,\n",
    " 'cc': 16,\n",
    " 'obj': 17,\n",
    " 'conj': 18,\n",
    " 'mark': 19,\n",
    " 'advcl': 20,\n",
    " 'nmod': 21,\n",
    " 'nummod': 22,\n",
    " 'dep': 23,\n",
    " 'xcomp': 24,\n",
    " 'ccomp': 25,\n",
    " 'parataxis': 26,\n",
    " 'compound:plur': 27,\n",
    " 'fixed': 28,\n",
    " 'aux': 29,\n",
    " 'csubj': 30,\n",
    " 'iobj': 31,\n",
    " 'csubj:pass': 32}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443\n",
      "DEBUG:urllib3.connectionpool:https://huggingface.co:443 \"HEAD /mesolitica/finetune-translation-t5-tiny-standard-bahasa-cased/resolve/main/config.json HTTP/1.1\" 200 0\n"
     ]
    }
   ],
   "source": [
    "config = T5Config.from_pretrained('mesolitica/finetune-translation-t5-tiny-standard-bahasa-cased')\n",
    "config.num_labels = len(tag2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of T5Diaparser were not initialized from the model checkpoint at mesolitica/finetune-translation-t5-tiny-standard-bahasa-cased and are newly initialized: ['mlp_arc_d.linear.bias', 'mlp_rel_d.linear.bias', 'rel_attn.weight', 'mlp_arc_h.linear.bias', 'mlp_rel_h.linear.weight', 'mlp_rel_h.linear.bias', 'mlp_arc_h.linear.weight', 'mlp_rel_d.linear.weight', 'mlp_arc_d.linear.weight', 'arc_attn.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = T5Diaparser.from_pretrained('mesolitica/finetune-translation-t5-tiny-standard-bahasa-cased',\n",
    "                                   config = config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443\n",
      "DEBUG:urllib3.connectionpool:https://huggingface.co:443 \"HEAD /mesolitica/finetune-translation-t5-tiny-standard-bahasa-cased/resolve/main/spiece.model HTTP/1.1\" 302 0\n"
     ]
    }
   ],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained('mesolitica/finetune-translation-t5-tiny-standard-bahasa-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from malaya.parser.conll import CoNLL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# conll = CoNLL()\n",
    "# conll('gsd-ud-dev.conllu.txt')\n",
    "# conll.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/huseinzol05/malay-dataset/master/parsing/dependency/gsd-ud-dev.conllu.txt\n",
    "groups, temp = [], []\n",
    "with open('gsd-ud-dev.conllu.txt') as fopen:\n",
    "    for l in fopen:\n",
    "        l = l.strip()\n",
    "        if not len(l):\n",
    "            groups.append(temp[2:])\n",
    "            temp = []\n",
    "        else:\n",
    "            temp.append(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_train(group):\n",
    "    texts, arcs, tags, indices = [], [], [], []\n",
    "    for g in group:\n",
    "        splitted = g.split('\\t')\n",
    "        texts.append(splitted[1])\n",
    "        arcs.append(int(splitted[6]))\n",
    "        tags.append(tag2idx[splitted[7]])\n",
    "        indices.append(int(splitted[0]))\n",
    "        \n",
    "    return texts, arcs, tags, indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# X, ARC, Y, MASK = [], [], [], []\n",
    "features = []\n",
    "for g in groups[:4]:\n",
    "    texts, arcs, tags, indices = [1], [0], [0], [0]\n",
    "    text, arc, tag, index = get_train(g)\n",
    "    for i in range(len(text)):\n",
    "        t = tokenizer.encode(text[i], add_special_tokens=False)\n",
    "        texts.extend(t)\n",
    "        arcs.extend([arc[i]] * len(t))\n",
    "        tags.extend([tag[i]] * len(t))\n",
    "        indices.extend([i + 1] * len(t))\n",
    "        \n",
    "#     X.append(texts)\n",
    "#     ARC.append(arcs)\n",
    "#     Y.append(tags)\n",
    "#     MASK.append([1] * len(texts))\n",
    "    \n",
    "    model_inputs = {\n",
    "        'input_ids': texts,\n",
    "        # 'attention_mask': [0] + [1] * (len(texts) - 1),\n",
    "        'attention_mask': [1] * len(texts),\n",
    "        'labels': tags,\n",
    "        'labels_arc': arcs,\n",
    "        'indices': indices\n",
    "    }\n",
    "    features.append(model_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_pad_token_id = -100\n",
    "padding_side = 'right'\n",
    "labels = [feature[\"labels\"] for feature in features]\n",
    "max_label_length = max(len(l) for l in labels)\n",
    "for feature in features:\n",
    "    remainder = [label_pad_token_id] * (max_label_length - len(feature[\"labels\"]))\n",
    "    remainder_ = [0] * (max_label_length - len(feature[\"labels\"]))\n",
    "    feature[\"labels\"] = (\n",
    "        feature[\"labels\"] + remainder if padding_side == \"right\" else remainder + feature[\"labels_tag\"]\n",
    "    )\n",
    "    feature[\"labels_arc\"] = (\n",
    "        feature[\"labels_arc\"] + remainder if padding_side == \"right\" else remainder + feature[\"labels_tag\"]\n",
    "    )\n",
    "    feature[\"indices\"] = (\n",
    "        feature[\"indices\"] + remainder_ if padding_side == \"right\" else remainder + feature[\"indices\"]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "padded = tokenizer.pad(\n",
    "    features,\n",
    "    padding=True,\n",
    "    max_length=None,\n",
    "    pad_to_multiple_of=None,\n",
    "    return_tensors='pt',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 78])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded['attention_mask'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 78])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded['input_ids'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  1,  2,  2,  2,  3,  4,  4,  4,  5,  6,  7,  8,  9, 10, 11, 11, 12,\n",
       "         13, 13, 14, 15, 15, 16, 17, 18, 19, 20, 21, 22, 23, 23,  0,  0,  0,  0,\n",
       "          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "          0,  0,  0,  0,  0,  0],\n",
       "        [ 0,  1,  2,  2,  3,  4,  5,  6,  6,  6,  7,  7,  8,  9, 10, 11, 12, 12,\n",
       "         13, 14, 14, 15, 15, 15, 16, 17, 18, 18, 18, 19, 20, 20, 21, 22, 23, 24,\n",
       "         25, 26, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 35,  0,  0,  0,  0,  0,\n",
       "          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "          0,  0,  0,  0,  0,  0],\n",
       "        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 10, 11, 12, 12, 13, 13, 13,\n",
       "         13, 13, 14, 14, 14, 14, 15, 16, 17, 18, 19, 20, 20, 21, 21, 22, 23, 24,\n",
       "         24, 24, 25, 26, 27, 28, 29, 30, 31, 31, 32, 33, 33, 33, 33, 34, 35, 36,\n",
       "         36, 37, 38, 39, 39, 39, 39, 39, 40, 41, 41, 42, 43, 44, 45, 45, 46, 47,\n",
       "         48, 49, 49, 50, 51, 51],\n",
       "        [ 0,  1,  2,  2,  3,  4,  5,  6,  7,  8,  8,  8,  9, 10, 11, 12, 13, 14,\n",
       "         14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
       "          0,  0,  0,  0,  0,  0]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded['indices']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "o = model(**padded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "         0, 0, 0, 0, 0, 0],\n",
       "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "o.s_arc.argmax(axis = -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(7.3922, grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "o.loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = Adam(model.parameters(),2e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor(7.3922, grad_fn=<AddBackward0>)\n",
      "1 tensor(7.1098, grad_fn=<AddBackward0>)\n",
      "2 tensor(6.7844, grad_fn=<AddBackward0>)\n",
      "3 tensor(6.4093, grad_fn=<AddBackward0>)\n",
      "4 tensor(5.9998, grad_fn=<AddBackward0>)\n",
      "5 tensor(5.5765, grad_fn=<AddBackward0>)\n",
      "6 tensor(5.1111, grad_fn=<AddBackward0>)\n",
      "7 tensor(4.7369, grad_fn=<AddBackward0>)\n",
      "8 tensor(4.0739, grad_fn=<AddBackward0>)\n",
      "9 tensor(3.6403, grad_fn=<AddBackward0>)\n",
      "10 tensor(3.1981, grad_fn=<AddBackward0>)\n",
      "11 tensor(2.7378, grad_fn=<AddBackward0>)\n",
      "12 tensor(2.2673, grad_fn=<AddBackward0>)\n",
      "13 tensor(1.8679, grad_fn=<AddBackward0>)\n",
      "14 tensor(1.5290, grad_fn=<AddBackward0>)\n",
      "15 tensor(1.4472, grad_fn=<AddBackward0>)\n",
      "16 tensor(1.0049, grad_fn=<AddBackward0>)\n",
      "17 tensor(0.8490, grad_fn=<AddBackward0>)\n",
      "18 tensor(0.6342, grad_fn=<AddBackward0>)\n",
      "19 tensor(0.4554, grad_fn=<AddBackward0>)\n",
      "20 tensor(0.3242, grad_fn=<AddBackward0>)\n",
      "21 tensor(0.2385, grad_fn=<AddBackward0>)\n",
      "22 tensor(0.1764, grad_fn=<AddBackward0>)\n",
      "23 tensor(0.1269, grad_fn=<AddBackward0>)\n",
      "24 tensor(0.0992, grad_fn=<AddBackward0>)\n",
      "25 tensor(0.0798, grad_fn=<AddBackward0>)\n",
      "26 tensor(0.0585, grad_fn=<AddBackward0>)\n",
      "27 tensor(0.0393, grad_fn=<AddBackward0>)\n",
      "28 tensor(0.0281, grad_fn=<AddBackward0>)\n",
      "29 tensor(0.0208, grad_fn=<AddBackward0>)\n",
      "30 tensor(0.0162, grad_fn=<AddBackward0>)\n",
      "31 tensor(0.0147, grad_fn=<AddBackward0>)\n",
      "32 tensor(0.0096, grad_fn=<AddBackward0>)\n",
      "33 tensor(0.0074, grad_fn=<AddBackward0>)\n",
      "34 tensor(0.0056, grad_fn=<AddBackward0>)\n",
      "35 tensor(0.0096, grad_fn=<AddBackward0>)\n",
      "36 tensor(0.0033, grad_fn=<AddBackward0>)\n",
      "37 tensor(0.0028, grad_fn=<AddBackward0>)\n",
      "38 tensor(0.0023, grad_fn=<AddBackward0>)\n",
      "39 tensor(0.0018, grad_fn=<AddBackward0>)\n",
      "40 tensor(0.0018, grad_fn=<AddBackward0>)\n",
      "41 tensor(0.0027, grad_fn=<AddBackward0>)\n",
      "42 tensor(0.0016, grad_fn=<AddBackward0>)\n",
      "43 tensor(0.0010, grad_fn=<AddBackward0>)\n",
      "44 tensor(0.0009, grad_fn=<AddBackward0>)\n",
      "45 tensor(0.0008, grad_fn=<AddBackward0>)\n",
      "46 tensor(0.0008, grad_fn=<AddBackward0>)\n",
      "47 tensor(0.0007, grad_fn=<AddBackward0>)\n",
      "48 tensor(0.0008, grad_fn=<AddBackward0>)\n",
      "49 tensor(0.0010, grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "for i in range(50):\n",
    "    optimizer.zero_grad()\n",
    "    loss = model(**padded, return_dict = True).loss\n",
    "    print(i, loss)\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "o = model(input_ids = padded['input_ids'], indices = padded['indices'],\n",
    "      attention_mask = padded['attention_mask'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  4,  1,  1,  1,  2,  0,  0,  0,  4,  7,  4,  7, 10,  4, 12, 12, 10,\n",
       "         14, 14, 10, 17, 17, 17, 10, 17, 20, 17, 20, 21,  4,  4,  4,  4,  7,  7,\n",
       "          7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,\n",
       "          7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,\n",
       "          7,  7,  7,  7,  7,  7],\n",
       "        [ 0,  3,  1,  1,  0,  5,  3,  5,  5,  5, 28, 28,  9,  3, 11,  9, 11, 11,\n",
       "         12, 28, 28,  3,  3,  3, 17, 15, 17, 17, 17, 18, 18, 18,  3, 23, 21, 23,\n",
       "         24, 28, 28, 28,  3, 28, 31, 29, 31, 32, 33,  3,  3,  3, 31, 31, 31, 31,\n",
       "          3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,\n",
       "          3,  3,  3,  3,  3,  3],\n",
       "        [ 0,  4,  1,  4,  0,  7,  7,  4,  7,  7,  4,  4, 13, 13, 13,  4,  4,  4,\n",
       "          4,  4, 13, 13, 13, 13, 14, 17, 13, 17, 18, 13, 13,  4,  4, 33, 22, 23,\n",
       "         23, 23, 22, 28, 28, 22, 30, 28, 33, 33, 33,  4,  4,  4,  4, 33, 36, 34,\n",
       "         34, 38, 34, 38, 38, 38, 38, 38, 39, 44, 44, 44, 42, 34, 44, 44, 45, 48,\n",
       "         45, 48, 48, 49, 33, 33],\n",
       "        [ 0,  4,  4,  4,  4,  0,  4,  4,  8,  4,  4,  4, 10,  8, 12, 10, 12,  4,\n",
       "          4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,\n",
       "          4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,\n",
       "          4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,\n",
       "          4,  4,  4,  4,  4,  4]])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "o.s_arc.argmax(axis = -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[   0,    4,    1,    1,    1,    2,    0,    0,    0,    4,    7,    4,\n",
       "            7,   10,    4,   12,   12,   10,   14,   14,   10,   17,   17,   17,\n",
       "           10,   17,   20,   17,   20,   21,    4,    4, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100],\n",
       "        [   0,    3,    1,    1,    0,    5,    3,    5,    5,    5,   28,   28,\n",
       "            9,    3,   11,    9,   11,   11,   12,   28,   28,    3,    3,    3,\n",
       "           17,   15,   17,   17,   17,   18,   18,   18,    3,   23,   21,   23,\n",
       "           24,   28,   28,   28,    3,   28,   31,   29,   31,   32,   33,    3,\n",
       "            3, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100],\n",
       "        [   0,    4,    1,    4,    0,    7,    7,    4,    7,    7,    4,    4,\n",
       "           13,   13,   13,    4,    4,    4,    4,    4,   13,   13,   13,   13,\n",
       "           14,   17,   13,   17,   18,   13,   13,    4,    4,   33,   22,   23,\n",
       "           23,   23,   22,   28,   28,   22,   30,   28,   33,   33,   33,    4,\n",
       "            4,    4,    4,   33,   36,   34,   34,   38,   34,   38,   38,   38,\n",
       "           38,   38,   39,   44,   44,   44,   42,   34,   44,   44,   45,   48,\n",
       "           45,   48,   48,   49,   33,   33],\n",
       "        [   0,    4,    4,    4,    4,    0,    4,    4,    8,    4,    4,    4,\n",
       "           10,    8,   12,   10,   12,    4,    4, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100]])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded['labels_arc']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "arc_preds = o.s_arc.argmax(axis = -1)\n",
    "rel_preds = o.s_rel.argmax(-1)\n",
    "rel_preds = rel_preds.gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[   0,    2,   14,   14,   14,   14,    5,    5,    5,   17,    8,    9,\n",
       "           13,    8,    9,   11,   11,   18,   11,   11,   18,   11,   11,   16,\n",
       "           18,   13,    2,    7,   17,   14,   11,   11, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100],\n",
       "        [   0,    2,   10,   10,    5,    8,    9,   13,   13,   13,   11,   11,\n",
       "           15,   13,    8,   21,   10,   10,   10,   11,   11,   13,   13,   13,\n",
       "            8,   21,   10,   10,   10,   10,   11,   11,   13,    8,   21,   10,\n",
       "           10,   11,   11,   16,   18,   15,    8,   21,   10,   10,   10,   11,\n",
       "           11, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100],\n",
       "        [   0,    6,    4,   15,    5,    8,    4,    9,   14,    4,   11,   11,\n",
       "            8,   11,   11,    9,    9,    9,    9,    9,   10,   10,   10,   10,\n",
       "           10,   16,   18,   10,   10,   11,   11,   11,   11,    2,   10,   10,\n",
       "           10,   10,    8,    6,   15,    7,    8,    9,   11,   11,   15,   23,\n",
       "           23,   23,   23,   17,    8,   21,   21,   16,   18,   14,   14,   14,\n",
       "           14,   14,   10,   11,   11,   16,   28,   18,   14,   14,   14,   16,\n",
       "           18,   14,   14,   14,   11,   11],\n",
       "        [   0,   19,    2,    2,    3,    5,   15,   14,    6,    7,    7,    7,\n",
       "            8,    9,   16,   18,    7,   11,   11, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n",
       "         -100, -100, -100, -100, -100, -100]])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded['labels']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  2, 14, 14, 14, 14,  5,  5,  5, 17,  8,  9, 13,  8,  9, 11, 11, 18,\n",
       "         11, 11, 18, 11, 11, 16, 18, 13,  2,  7, 17, 14, 11, 11, 17, 17,  8,  8,\n",
       "          8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,\n",
       "          8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,\n",
       "          8,  8,  8,  8,  8,  8],\n",
       "        [ 0,  2, 10, 10,  5,  8,  9, 13, 13, 13, 11, 11, 15, 13,  8, 21, 10, 10,\n",
       "         10, 11, 11, 13, 13, 13,  8, 21, 10, 10, 10, 10, 11, 11, 13,  8, 21, 10,\n",
       "         10, 11, 11, 16, 18, 15,  8, 21, 10, 10, 10, 11, 11, 11, 10, 10, 21, 21,\n",
       "         13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,\n",
       "         13, 13, 13, 13, 13, 13],\n",
       "        [ 0,  6,  4, 15,  5,  8,  4,  9, 14,  4, 11, 11,  8, 11, 11,  9,  9,  9,\n",
       "          9,  9, 10, 10, 10, 10, 10, 16, 18, 10, 10, 11, 11, 11, 11,  2, 10, 10,\n",
       "         10, 10,  8,  6, 15,  7,  8,  9, 11, 11, 15, 23, 23, 23, 23, 17,  8, 21,\n",
       "         21, 16, 18, 14, 14, 14, 14, 14, 10, 11, 11, 16, 28, 18, 14, 14, 14, 16,\n",
       "         18, 14, 14, 14, 11, 11],\n",
       "        [ 0, 19,  2,  2,  3,  5, 15, 14,  6,  7,  7,  7,  8,  9, 16, 18,  7, 11,\n",
       "         11, 11, 11, 11, 11, 11, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,\n",
       "         19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,\n",
       "         19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,\n",
       "         19, 19, 19, 19, 19, 19]])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rel_preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
